{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 用 2DCNN 构建文本分类器 —— 来自 Kaggle Quora 竞赛 Kernel：[2DCNN TextClassifier](https://www.kaggle.com/yekenot/2dcnn-textclassifier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
    "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "np.random.seed(42)\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import f1_score\n",
    "from keras.models import Model\n",
    "from keras.layers import Input, Embedding, Dense, Conv2D, MaxPool2D\n",
    "from keras.layers import Reshape, Flatten, Concatenate, Dropout, SpatialDropout1D\n",
    "from keras.preprocessing import text, sequence\n",
    "from keras.callbacks import Callback\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>qid</th>\n",
       "      <th>question_text</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>00002165364db923c7e6</td>\n",
       "      <td>How did Quebec nationalists see their province...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>000032939017120e6e44</td>\n",
       "      <td>Do you have an adopted dog, how would you enco...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0000412ca6e4628ce2cf</td>\n",
       "      <td>Why does velocity affect time? Does velocity a...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>000042bf85aa498cd78e</td>\n",
       "      <td>How did Otto von Guericke used the Magdeburg h...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0000455dfa3e01eae3af</td>\n",
       "      <td>Can I convert montra helicon D to a mountain b...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>00004f9a462a357c33be</td>\n",
       "      <td>Is Gaza slowly becoming Auschwitz, Dachau or T...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>00005059a06ee19e11ad</td>\n",
       "      <td>Why does Quora automatically ban conservative ...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0000559f875832745e2e</td>\n",
       "      <td>Is it crazy if I wash or wipe my groceries off...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>00005bd3426b2d0c8305</td>\n",
       "      <td>Is there such a thing as dressing moderately, ...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>00006e6928c5df60eacb</td>\n",
       "      <td>Is it just me or have you ever been in this ph...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                    qid                                      question_text  \\\n",
       "0  00002165364db923c7e6  How did Quebec nationalists see their province...   \n",
       "1  000032939017120e6e44  Do you have an adopted dog, how would you enco...   \n",
       "2  0000412ca6e4628ce2cf  Why does velocity affect time? Does velocity a...   \n",
       "3  000042bf85aa498cd78e  How did Otto von Guericke used the Magdeburg h...   \n",
       "4  0000455dfa3e01eae3af  Can I convert montra helicon D to a mountain b...   \n",
       "5  00004f9a462a357c33be  Is Gaza slowly becoming Auschwitz, Dachau or T...   \n",
       "6  00005059a06ee19e11ad  Why does Quora automatically ban conservative ...   \n",
       "7  0000559f875832745e2e  Is it crazy if I wash or wipe my groceries off...   \n",
       "8  00005bd3426b2d0c8305  Is there such a thing as dressing moderately, ...   \n",
       "9  00006e6928c5df60eacb  Is it just me or have you ever been in this ph...   \n",
       "\n",
       "   target  \n",
       "0       0  \n",
       "1       0  \n",
       "2       0  \n",
       "3       0  \n",
       "4       0  \n",
       "5       0  \n",
       "6       0  \n",
       "7       0  \n",
       "8       0  \n",
       "9       0  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the labelled training questions and preview the first ten rows.\n",
    "train = pd.read_csv(\"./train.csv\")\n",
    "train.head(n=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>qid</th>\n",
       "      <th>question_text</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>00014894849d00ba98a9</td>\n",
       "      <td>My voice range is A2-C5. My chest voice goes u...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>000156468431f09b3cae</td>\n",
       "      <td>How much does a tutor earn in Bangalore?</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>000227734433360e1aae</td>\n",
       "      <td>What are the best made pocket knives under $20...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0005e06fbe3045bd2a92</td>\n",
       "      <td>Why would they add a hypothetical scenario tha...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>00068a0f7f41f50fc399</td>\n",
       "      <td>What is the dresscode for Techmahindra freshers?</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>000a2d30e3ffd70c070d</td>\n",
       "      <td>How well are you adapting to the Trump era?</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>000b67672ec9622ff761</td>\n",
       "      <td>What should be the last thing people do in life?</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>000b7fb1146d712c1105</td>\n",
       "      <td>Received conditional offer for Masters in Inte...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>000d665a8ddc426a1907</td>\n",
       "      <td>What does appareils photo mean in French?</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>000df6fd2229447b2969</td>\n",
       "      <td>Is there a system of Public Interest Litigatio...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                    qid                                      question_text\n",
       "0  00014894849d00ba98a9  My voice range is A2-C5. My chest voice goes u...\n",
       "1  000156468431f09b3cae           How much does a tutor earn in Bangalore?\n",
       "2  000227734433360e1aae  What are the best made pocket knives under $20...\n",
       "3  0005e06fbe3045bd2a92  Why would they add a hypothetical scenario tha...\n",
       "4  00068a0f7f41f50fc399   What is the dresscode for Techmahindra freshers?\n",
       "5  000a2d30e3ffd70c070d        How well are you adapting to the Trump era?\n",
       "6  000b67672ec9622ff761   What should be the last thing people do in life?\n",
       "7  000b7fb1146d712c1105  Received conditional offer for Masters in Inte...\n",
       "8  000d665a8ddc426a1907          What does appareils photo mean in French?\n",
       "9  000df6fd2229447b2969  Is there a system of Public Interest Litigatio..."
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the unlabelled test questions and preview the first ten rows.\n",
    "test = pd.read_csv(\"./test.csv\")\n",
    "test.head(n=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Template frame for the submission file; the 'prediction' column is filled in after inference.\n",
    "submission = pd.read_csv(\"./sample_submission.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "_uuid": "b60b0950ddf9fce1c088bf2f89d07ce04c3a82e1"
   },
   "outputs": [],
   "source": [
    "# Path of the GloVe embedding text file that is read when the embedding matrix is built.\n",
    "EMBEDDING_FILE = \"./glove.840B.300d/glove.840B.300d.txt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['How did Quebec nationalists see their province as a nation in the 1960s?',\n",
       "       'Do you have an adopted dog, how would you encourage people to adopt and not shop?',\n",
       "       'Why does velocity affect time? Does velocity affect space geometry?',\n",
       "       'How did Otto von Guericke used the Magdeburg hemispheres?',\n",
       "       'Can I convert montra helicon D to a mountain bike by just changing the tyres?',\n",
       "       'Is Gaza slowly becoming Auschwitz, Dachau or Treblinka for Palestinians?',\n",
       "       'Why does Quora automatically ban conservative opinions when reported, but does not do the same for liberal views?',\n",
       "       'Is it crazy if I wash or wipe my groceries off? Germs are everywhere.',\n",
       "       'Is there such a thing as dressing moderately, and if so, how is that different than dressing modestly?',\n",
       "       'Is it just me or have you ever been in this phase wherein you became ignorant to the people you once loved, completely disregarding their feelings/lives so you get to have something go your way and feel temporarily at ease. How did things change?'],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the raw text of the first ten training questions (missing values shown as 'fillna').\n",
    "train[\"question_text\"].head(10).fillna(\"fillna\").values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "_uuid": "784c915cb895f5c7586d4dd61d99d70a3c9a5651"
   },
   "outputs": [],
   "source": [
    "# Text preprocessing, in four steps:\n",
    "#   1. Fit a tokenizer on the train and test questions, with the vocabulary\n",
    "#      capped through num_words=max_features (40000).\n",
    "#   2. Convert every question into a sequence of integer word indices.\n",
    "#   3. Pad the sequences to a common length of maxlen (50).\n",
    "#   4. embed_size (300) is the width of the embedding vectors used by later cells.\n",
    "\n",
    "max_features = 40000\n",
    "maxlen = 50\n",
    "embed_size = 300\n",
    "\n",
    "# Missing questions are replaced by the literal string \"fillna\".\n",
    "X_train = train[\"question_text\"].fillna(\"fillna\").values\n",
    "X_test = test[\"question_text\"].fillna(\"fillna\").values\n",
    "y_train = train[\"target\"].values\n",
    "\n",
    "# fit_on_texts() updates the tokenizer's internal vocabulary from a list of texts.\n",
    "tokenizer = text.Tokenizer(num_words=max_features)\n",
    "tokenizer.fit_on_texts(list(X_train) + list(X_test))\n",
    "\n",
    "# texts_to_sequences() turns each text into a sequence of integers.\n",
    "X_train, X_test = (tokenizer.texts_to_sequences(texts) for texts in (X_train, X_test))\n",
    "\n",
    "# pad_sequences() pads the sequences so that they all share the same length.\n",
    "x_train, x_test = (sequence.pad_sequences(seqs, maxlen=maxlen) for seqs in (X_train, X_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "_uuid": "f2ee3631454ab235e7951a2313f11100bf7b7d97"
   },
   "outputs": [],
   "source": [
    "# Build the embedding matrix: row i holds the pre-trained vector of the word whose\n",
    "# tokenizer index is i; words absent from the embedding file keep an all-zero row.\n",
    "\n",
    "def get_coefs(word, *arr):\n",
    "    \"\"\"Return (word, vector), converting the coefficient strings to a float32 array.\"\"\"\n",
    "    return word, np.asarray(arr, dtype='float32')\n",
    "\n",
    "\n",
    "def _load_embeddings(path, dim):\n",
    "    \"\"\"Read a whitespace-separated embedding text file into a {token: vector} dict.\n",
    "\n",
    "    Every line is split from the right into exactly `dim` coefficients, so a token\n",
    "    that itself contains spaces stays in one piece instead of breaking the float\n",
    "    conversion. Lines that do not carry `dim` coefficients are skipped.\n",
    "    \"\"\"\n",
    "    vectors = {}\n",
    "    # Explicit UTF-8 plus a context manager: decoding does not depend on the\n",
    "    # platform default and the file handle is closed as soon as reading ends.\n",
    "    with open(path, encoding='utf-8') as handle:\n",
    "        for line in handle:\n",
    "            parts = line.rstrip().rsplit(' ', dim)\n",
    "            if len(parts) != dim + 1:\n",
    "                continue\n",
    "            token, vector = get_coefs(*parts)\n",
    "            vectors[token] = vector\n",
    "    return vectors\n",
    "\n",
    "\n",
    "embeddings_index = _load_embeddings(EMBEDDING_FILE, embed_size)\n",
    "\n",
    "# word_index: dict mapping each word (string) to its rank / index.\n",
    "word_index = tokenizer.word_index\n",
    "# Tokenizer indices start at 1, so len(word_index) + 1 rows are needed to address every word.\n",
    "nb_words = min(max_features, len(word_index) + 1)\n",
    "# Always allocate max_features rows: the Embedding layer in get_model() is declared with\n",
    "# input_dim=max_features and its initial weights must have exactly that many rows.\n",
    "embedding_matrix = np.zeros((max_features, embed_size))\n",
    "for word, i in word_index.items():\n",
    "    if i >= nb_words:\n",
    "        continue\n",
    "    embedding_vector = embeddings_index.get(word)\n",
    "    if embedding_vector is not None:\n",
    "        embedding_matrix[i] = embedding_vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "_uuid": "bfd80e14218851850452b086565d4f8414a0517d"
   },
   "outputs": [],
   "source": [
    "class F1Evaluation(Callback):\n",
    "    \"\"\"Keras callback that prints the F1 score on a held-out set at the end of epochs.\n",
    "\n",
    "    Args:\n",
    "        validation_data: (X_val, y_val) pair; predictions on X_val are scored against y_val.\n",
    "        interval: the score is computed on every epoch whose 0-based index is a\n",
    "            multiple of this value (1 = every epoch).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, validation_data=(), interval=1):\n",
    "        # super() has to start from this class so that Callback.__init__ really runs;\n",
    "        # super(Callback, self) would skip it and go straight to object.__init__.\n",
    "        super().__init__()\n",
    "\n",
    "        self.interval = interval\n",
    "        self.X_val, self.y_val = validation_data\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        \"\"\"Predict on the validation set, threshold at 0.5 and print the F1 score.\n",
    "\n",
    "        `logs` is accepted for compatibility with the callback interface and is not used.\n",
    "        \"\"\"\n",
    "        if epoch % self.interval == 0:\n",
    "            y_pred = self.model.predict(self.X_val, verbose=0)\n",
    "            # Turn probabilities into hard 0/1 labels before scoring.\n",
    "            y_pred = (y_pred > 0.5).astype(int)\n",
    "            score = f1_score(self.y_val, y_pred)\n",
    "            print(\"\\n F1 Score - epoch: %d - score: %.6f \\n\" % (epoch+1, score))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "_uuid": "7a3fed14c968b22e18f6c6c855ab9c23b9f0cce3"
   },
   "outputs": [],
   "source": [
    "# Heights (in tokens) of the convolution windows; one Conv2D + MaxPool2D branch is built per entry.\n",
    "filter_sizes = [1,2,3,5]\n",
    "# Number of filters in every convolution branch.\n",
    "num_filters = 35\n",
    "\n",
    "def get_model():\n",
    "    \"\"\"Build and compile the 2D-CNN text classifier.\n",
    "\n",
    "    The padded index sequence is embedded (initial weights taken from\n",
    "    `embedding_matrix`), reshaped to a one-channel 2D map and passed through one\n",
    "    convolution + max-pooling branch per entry of `filter_sizes`. The pooled\n",
    "    branches are concatenated, flattened and fed to a single sigmoid unit.\n",
    "\n",
    "    Returns:\n",
    "        A Model compiled with binary cross-entropy loss, the Adam optimizer and\n",
    "        the accuracy metric.\n",
    "    \"\"\"\n",
    "    inp = Input(shape=(maxlen, ))\n",
    "    x = Embedding(max_features, embed_size, weights=[embedding_matrix])(inp)\n",
    "    # SpatialDropout1D works like Dropout, but drops entire 1D feature maps rather than single units.\n",
    "    x = SpatialDropout1D(0.4)(x)\n",
    "    # Append a channel axis so the embedded sequence can be consumed by Conv2D.\n",
    "    x = Reshape((maxlen, embed_size, 1))(x)\n",
    "\n",
    "    # One convolution per window size; every kernel spans the full embedding width.\n",
    "    convs = [Conv2D(num_filters, kernel_size=(size, embed_size),\n",
    "                    kernel_initializer='normal', activation='elu')(x)\n",
    "             for size in filter_sizes]\n",
    "\n",
    "    # Pool over all maxlen - size + 1 window positions of a branch, leaving one value per filter.\n",
    "    pooled = [MaxPool2D(pool_size=(maxlen - size + 1, 1))(conv)\n",
    "              for size, conv in zip(filter_sizes, convs)]\n",
    "\n",
    "    # Concatenate takes a list of same-shape tensors and joins them along the given axis (axis=1 here).\n",
    "    # It needs at least two inputs, so a single branch is passed through unchanged.\n",
    "    z = Concatenate(axis=1)(pooled) if len(pooled) > 1 else pooled[0]\n",
    "    z = Flatten()(z)\n",
    "    z = Dropout(0.1)(z)\n",
    "\n",
    "    outp = Dense(1, activation=\"sigmoid\")(z)\n",
    "\n",
    "    model = Model(inputs=inp, outputs=outp)\n",
    "    model.compile(loss='binary_crossentropy',\n",
    "                  optimizer='adam',\n",
    "                  metrics=['accuracy'])\n",
    "\n",
    "    return model\n",
    "\n",
    "model = get_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_uuid": "51deacab40f77c038525faf1cfd2c8e820f7f01b"
   },
   "outputs": [],
   "source": [
    "# Training settings.\n",
    "epochs = 3\n",
    "batch_size = 256\n",
    "\n",
    "# Split the padded training sequences: train_size=0.95, the remainder is used for validation.\n",
    "X_tra, X_val, y_tra, y_val = train_test_split(\n",
    "    x_train, y_train, train_size=0.95, random_state=233)\n",
    "\n",
    "# Callback that prints the validation F1 score at the end of every epoch.\n",
    "F1_Score = F1Evaluation(validation_data=(X_val, y_val), interval=1)\n",
    "\n",
    "hist = model.fit(\n",
    "    X_tra, y_tra,\n",
    "    epochs=epochs,\n",
    "    batch_size=batch_size,\n",
    "    validation_data=(X_val, y_val),\n",
    "    callbacks=[F1_Score],\n",
    "    verbose=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_uuid": "9607e11c2591d51bb31da729c9934a9a69de854e"
   },
   "outputs": [],
   "source": [
    "# Predicted probabilities for the test questions; the model ends in Dense(1), so the\n",
    "# result has one column per sample.\n",
    "y_pred = model.predict(x_test, batch_size=1024)\n",
    "# Threshold at 0.5, then flatten the single-column array into the 1-D vector a\n",
    "# DataFrame column expects instead of relying on pandas to accept a 2-D value.\n",
    "y_pred = (y_pred > 0.5).astype(int).ravel()\n",
    "submission['prediction'] = y_pred\n",
    "submission.to_csv('submission.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_uuid": "7685bcc3e67a411ac41c6f96ef393016cee83a1f"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
